"""Metric utilities for Simulation B.

This module implements functions to compute the root mean squared error
between measured visibilities and predicted visibilities, as well as
confidence intervals via bootstrapping.  These routines are used by
the analysis and plotting scripts to assess agreement with theory.
"""

from __future__ import annotations

from typing import Iterable, Tuple, Sequence

import numpy as np


def rmse(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Compute the root-mean-square deviation of ``y_pred`` from ``y_true``.

    Both inputs are converted to float arrays and must have identical
    shapes; a ``ValueError`` is raised otherwise.
    """
    truth = np.asarray(y_true, dtype=float)
    pred = np.asarray(y_pred, dtype=float)
    if truth.shape == pred.shape:
        residuals = truth - pred
        mean_sq = np.mean(np.square(residuals))
        return float(np.sqrt(mean_sq))
    raise ValueError("Shape mismatch in RMSE computation")


def bootstrap_ci(
    data: Iterable[float],
    ci: float = 0.68,
    n_bootstrap: int = 1000,
    seed: int | None = 12345,
) -> Tuple[float, float]:
    """Return a central confidence interval around the mean of ``data`` via bootstrapping.

    The sample mean is recomputed on ``n_bootstrap`` resamples drawn with
    replacement from ``data``.  The bounds are order statistics of the sorted
    resampled means that leave roughly ``(1 - ci) / 2`` of them in each tail.

    Parameters
    ----------
    data : iterable of floats
        Observations.  The iterable is consumed once, so generators are accepted.
    ci : float, default 0.68
        Fraction of mass inside the confidence interval, in ``(0, 1]``.
        For 68 % CI use 0.68; for 95 % use 0.95.
    n_bootstrap : int, default 1000
        Number of bootstrap resamples; must be at least 1.
    seed : int or None, default 12345
        Seed passed to ``numpy.random.default_rng``.  The default keeps results
        reproducible and identical to earlier versions of this function; pass
        ``None`` for fresh entropy on every call.

    Returns
    -------
    (low, high) : tuple of floats
        Lower and upper bounds of the central confidence interval, with
        ``low <= high``.  ``(0.0, 0.0)`` if ``data`` is empty.

    Raises
    ------
    ValueError
        If ``ci`` is not in ``(0, 1]`` or ``n_bootstrap`` is less than 1.
    """
    # Written as a negated chained comparison so that NaN is rejected too.
    if not 0.0 < ci <= 1.0:
        raise ValueError(f"ci must be in (0, 1], got {ci!r}")
    if n_bootstrap < 1:
        raise ValueError(f"n_bootstrap must be at least 1, got {n_bootstrap!r}")

    arr = np.asarray(list(data), dtype=float)
    if arr.size == 0:
        return (0.0, 0.0)

    rng = np.random.default_rng(seed)
    # Draw one resample per iteration (rather than one batched draw) so the
    # random stream, and therefore seeded results, match the original code.
    means = np.empty(n_bootstrap, dtype=float)
    for i in range(n_bootstrap):
        means[i] = rng.choice(arr, size=arr.size, replace=True).mean()
    means.sort()

    alpha = 1.0 - ci
    low_idx = int((alpha / 2.0) * n_bootstrap)
    high_idx = int((1.0 - alpha / 2.0) * n_bootstrap) - 1
    # With few resamples, truncation can put high_idx below low_idx (or at -1,
    # which would wrap to the last element).  In that case, collapse both
    # bounds onto the same order statistic so that low <= high always holds.
    high_idx = max(high_idx, low_idx)
    return (float(means[low_idx]), float(means[high_idx]))